from collections import OrderedDict
from copy import deepcopy

from softlearning.preprocessors.utils import get_preprocessor_from_params

from collections import OrderedDict
from copy import deepcopy

from softlearning.preprocessors.utils import get_preprocessor_from_params
from .base_critic import FeedforwardCritic


def get_feedforward_ciritc(*args, **kwargs):
    """Construct a ``FeedforwardCritic``, forwarding every argument unchanged.

    The (misspelled) function name is preserved because ``CRITIC_FUNCTIONS``
    refers to it by this name.
    """
    from .base_critic import FeedforwardCritic
    return FeedforwardCritic(*args, **kwargs)

def get_critic(critic_type, *args, **kwargs):
    """Instantiate the critic registered under ``critic_type``.

    The factory is looked up in ``CRITIC_FUNCTIONS`` and receives all remaining
    positional and keyword arguments as-is.

    Raises:
        KeyError: if ``critic_type`` is not registered in ``CRITIC_FUNCTIONS``.
    """
    factory = CRITIC_FUNCTIONS[critic_type]
    return factory(*args, **kwargs)

# Registry: critic ``type`` string (as read from ``critic_params['type']``)
# -> factory callable that builds the critic.
CRITIC_FUNCTIONS = dict(
    FeedforwardCritic=get_feedforward_ciritc,
)

def get_critic_from_variant(variant, *args, **kwargs):
    """Build a critic from ``variant['critic_params']``.

    All extra arguments are forwarded to ``get_critic_from_params``.

    Raises:
        KeyError: if ``variant`` has no ``'critic_params'`` entry.
    """
    return get_critic_from_params(variant['critic_params'], *args, **kwargs)

def get_extractor_from_variant(variant, *args, **kwargs):
    """Build an extractor from ``variant['extractor_params']``.

    Extractors are built through the same machinery as critics, so the params
    are handed to ``get_critic_from_params`` along with any extra arguments.

    Raises:
        KeyError: if ``variant`` has no ``'extractor_params'`` entry.
    """
    extractor_params = variant['extractor_params']
    return get_critic_from_params(extractor_params, *args, **kwargs)

def _get_observation_preprocessors(env, observation_names,
                                   observation_preprocessors_params):
    """Build one preprocessor (or ``None``) per observation name.

    Args:
        env: Environment, passed through to ``get_preprocessor_from_params``.
        observation_names: Iterable of observation keys; the returned
            ``OrderedDict`` follows this iteration order.
        observation_preprocessors_params: Mapping from observation key to
            preprocessor params. A key that is missing, or maps to a falsy
            value, gets ``None`` (no preprocessing).

    Returns:
        OrderedDict mapping each observation name to its preprocessor or None.
    """
    observation_preprocessors = OrderedDict()
    for name in observation_names:
        preprocessor_params = observation_preprocessors_params.get(name)
        observation_preprocessors[name] = (
            get_preprocessor_from_params(env, preprocessor_params)
            if preprocessor_params
            else None)
    return observation_preprocessors


def get_critic_from_params(critic_params, env, *args, **kwargs):
    """Build a critic from a ``{'type': ..., 'kwargs': {...}}`` params dict.

    ``critic_params['kwargs']`` may contain two special entries that are
    consumed here rather than forwarded to the critic factory:

    * ``'observation_preprocessors_params'``: mapping from observation key to
      preprocessor params (see ``get_preprocessor_from_params``).
    * ``'observation_keys'``: which observations the critic sees; defaults to
      ``env.observation_keys`` when absent or falsy.

    All other kwargs entries, plus ``*args`` / ``**kwargs``, are forwarded to
    the factory registered under ``critic_params['type']`` in
    ``CRITIC_FUNCTIONS``, together with ``input_shapes``,
    ``observation_keys`` and ``preprocessors``.

    Args:
        critic_params: Dict with a ``'type'`` key and an optional ``'kwargs'``
            dict. It is not mutated.
        env: Environment exposing ``observation_keys``, ``observation_shape``
            (a mapping of key -> shape) and ``action_shape``.

    Returns:
        The critic instance built by the registered factory.

    Raises:
        KeyError: if ``critic_params`` lacks ``'type'`` or the type is not
            registered in ``CRITIC_FUNCTIONS``.
        TypeError: if a key appears both in ``critic_params['kwargs']`` and
            in ``**kwargs`` (duplicate keyword argument).
    """
    critic_type = critic_params['type']

    # Resolve the factory first so an unknown type fails before any
    # preprocessors are built.
    try:
        critic_factory = CRITIC_FUNCTIONS[critic_type]
    except KeyError:
        raise KeyError(
            "Unknown critic type {!r}; expected one of {}.".format(
                critic_type, sorted(CRITIC_FUNCTIONS))) from None

    # Deep-copy so popping the special entries never mutates the caller's
    # config; ``or {}`` also accepts an explicit ``None``.
    critic_kwargs = deepcopy(critic_params.get('kwargs') or {})

    observation_preprocessors_params = critic_kwargs.pop(
        'observation_preprocessors_params', None) or {}
    observation_keys = critic_kwargs.pop(
        'observation_keys', None) or env.observation_keys

    # NOTE: the order follows ``env.observation_shape``, not
    # ``observation_keys``.
    observation_shapes = OrderedDict(
        (key, value) for key, value in env.observation_shape.items()
        if key in observation_keys)

    input_shapes = {
        'observations': observation_shapes,
        'actions': env.action_shape,
    }
    preprocessors = {
        'observations': _get_observation_preprocessors(
            env, observation_shapes, observation_preprocessors_params),
        # Actions are never preprocessed here.
        'actions': None,
    }

    critic = critic_factory(
        *args,
        input_shapes=input_shapes,
        observation_keys=observation_keys,
        preprocessors=preprocessors,
        **critic_kwargs,
        **kwargs)

    return critic





